#!/usr/bin/env python3
"""
generate_flip_counts.py

Generate a per‑link array of tick‑flip counts for an L×L periodic lattice.  This
script generalises the Volume‑4 flip‑count simulator to arbitrary lattice size.

The tick‑flip algebra is implemented directly in this script using the
definitions from the private ``ar‑operator‑core`` repository.  For each link
in the lattice the fixed operator cycle (F, S, X, C, Φ) is applied
repeatedly, starting from a unit spike at the centre of the distribution, and
the number of changes at a link‑dependent context index is counted.  The
cycle is deterministic, so the ``--seed`` option does not affect the counts.
The resulting array of length ``2*L*L`` is saved to the specified output
path.

Example usage (to generate a 6×6 lattice with seed 0 and write to
``data/flip_counts.npy``):

    python generate_flip_counts.py --lattice-size 6 --seed 0 --output data/flip_counts.npy

This script intentionally avoids any placeholder random counts.  It uses the
fractal‑pivot tick‑flip operators (F, S, X, C, Φ) exactly as in the original
Volume‑4 flip‑count simulator.
"""

import argparse
import os
import numpy as np
from dataclasses import dataclass


@dataclass(frozen=True)
class TickState:
    """Container for a tick distribution and its context depth N.

    The distribution must be a 1-D ``numpy.ndarray`` of length ``2*N + 1``;
    this is checked in ``__post_init__``.  Note that ``frozen=True`` only
    prevents rebinding the attributes -- the array itself remains mutable.
    """

    # 1-D array of length 2*N + 1 holding the tick mass per context index.
    distribution: np.ndarray
    # Context depth (half-range of the distribution); must be >= 0.
    N: int

    def __post_init__(self):
        """Validate ``distribution`` against ``N``.

        Raises
        ------
        TypeError
            If ``distribution`` is not a ``numpy.ndarray``.
        ValueError
            If ``distribution`` is not 1-D, if ``N`` is negative, or if the
            length of ``distribution`` is not ``2*N + 1``.
        """
        if not isinstance(self.distribution, np.ndarray):
            raise TypeError("distribution must be a numpy.ndarray")
        if self.distribution.ndim != 1:
            raise ValueError("distribution must be a 1D array")
        # Check N explicitly: otherwise a negative N surfaces as a confusing
        # "length must be 2*N+1 (-1)" message that no array can ever satisfy.
        if self.N < 0:
            raise ValueError(f"N must be non-negative, got {self.N}")
        expected_len = 2 * self.N + 1
        if self.distribution.size != expected_len:
            raise ValueError(
                f"distribution length must be 2*N+1 ({expected_len}), got {self.distribution.size}"
            )


def renewal(state: TickState) -> TickState:
    """Apply the Renewal (F) operator: shift mass one index toward position 0.

    Each entry at index ``i >= 2`` moves to ``i - 1``.  Index 0 keeps its own
    mass and absorbs the mass arriving from index 1, so the total mass is
    conserved.  The last entry is left empty.  A single-entry distribution
    (``N == 0``) has nowhere to move and is returned unchanged, matching ``S``.

    Parameters
    ----------
    state : TickState
        Input state; it is not modified.

    Returns
    -------
    TickState
        New state with the same ``N``.
    """
    dist = state.distribution
    size = dist.size
    new = np.zeros_like(dist)
    if size == 1:
        # Handled separately: blindly zeroing the last slot would also wipe
        # new[0] here, because index 0 *is* the last slot.
        new[0] = dist[0]
        return TickState(new, state.N)
    new[0] = dist[0] + dist[1]
    new[1 : size - 1] = dist[2:]
    # new[size - 1] is already 0 from np.zeros_like.
    return TickState(new, state.N)


def F(state: TickState) -> TickState:
    """Short alias for the Renewal operator: return ``renewal(state)``."""
    renewed = renewal(state)
    return renewed


def S(state: TickState) -> TickState:
    """Apply the Sink operator: shift mass one index toward the last position.

    Each entry at index ``i <= size - 3`` moves to ``i + 1``.  The final entry
    keeps its own mass and absorbs the mass arriving from ``size - 2``, so the
    total mass is conserved and index 0 is left empty.  A single-entry
    distribution is returned unchanged.
    """
    src = state.distribution
    n = src.size
    out = np.zeros_like(src)
    if n <= 1:
        out[0] = src[0]
        return TickState(out, state.N)
    if n > 2:
        out[1 : n - 1] = src[: n - 2]
    out[n - 1] = src[n - 1] + src[n - 2]
    return TickState(out, state.N)


def X(state: TickState) -> TickState:
    """Apply the Distinction operator: return the distribution in reverse order.

    The reversed view is copied so the new state never shares memory with the
    input state.
    """
    mirrored = state.distribution[::-1]
    return TickState(mirrored.copy(), state.N)


def C(state: TickState) -> TickState:
    """Apply the Sync operator: average each entry with its mirror entry.

    The result is a fresh, reversal-symmetric array with the same ``N``.
    """
    values = state.distribution
    mirrored = values[::-1]
    symmetric = 0.5 * (values + mirrored)
    return TickState(symmetric, state.N)


def Phi(state: TickState) -> TickState:
    """Apply the Frame coupling operator Φ = C ∘ X (reverse, then sync).

    Because C averages a distribution with its own reverse, reversing first
    does not change the outcome: Φ(state) equals C(state).
    """
    reversed_state = X(state)
    return C(reversed_state)


def build_default_lattice(size: int, boundary: str = "periodic") -> np.ndarray:
    """
    Enumerate the links of a 2D ``size`` x ``size`` square lattice.

    Each link is recorded as ``((x, y), mu)``, where ``mu == 0`` is the +x
    link and ``mu == 1`` the +y link leaving site ``(x, y)``.  Sites are
    visited in x-major, then y, then mu order.

    With ``boundary == "periodic"`` every site contributes both links
    (2*size*size in total).  Any other ``boundary`` value is treated as an
    open boundary: links whose far end would fall outside the lattice are
    dropped.

    The links are packed with ``np.array(..., dtype=object)``.  Because each
    entry mixes a tuple with an int, NumPy stores them as rows rather than as
    single tuple objects; each row still unpacks as ``(x, y), mu = row``.
    """
    periodic = boundary == "periodic"
    steps = ((1, 0), (0, 1))  # mu=0: +x, mu=1: +y
    kept = [
        ((x, y), mu)
        for x in range(size)
        for y in range(size)
        for mu, (dx, dy) in enumerate(steps)
        # Periodic links always wrap onto the lattice; open links must land inside.
        if periodic or (0 <= x + dx < size and 0 <= y + dy < size)
    ]
    return np.array(kept, dtype=object)


def count_flips_on_link(link: tuple, sim_params: dict) -> int:
    """
    Count how often the value at a link-specific watch index changes while
    the tick-flip operators are applied.

    The state starts as a unit spike at the centre index ``N`` of a length
    ``2*N + 1`` distribution.  The fixed operator cycle ``F, S, X, C, Phi`` is
    then applied ``steps_per_link`` times.  After every single operator
    application, the entry at the watch index is compared (via
    ``np.isclose``) with its previous value, and each difference counts as one
    flip.

    NOTE(review): the operator order is fixed, not random.  The result
    therefore depends only on ``N``, ``steps_per_link`` and the watch index.
    ``sim_params["seed"]`` is accepted for compatibility but has no effect,
    and no random number generator is used.

    Parameters
    ----------
    link : tuple
        ``((x, y), mu)`` link descriptor; only ``x + y + mu`` is used.
    sim_params : dict
        Must contain ``"N"`` (context depth) and ``"steps_per_link"``
        (number of full operator cycles).  An optional ``"seed"`` is ignored.

    Returns
    -------
    int
        Number of operator applications that changed the watched entry.
    """
    N = sim_params["N"]
    steps = sim_params["steps_per_link"]

    # Unit delta spike at the centre of the distribution.
    dist0 = np.zeros(2 * N + 1)
    dist0[N] = 1.0
    state = TickState(dist0, N)

    # Map the link onto the context index whose value is monitored.
    (x, y), mu = link
    watch_idx = (x + y + mu) % (2 * N + 1)

    ops = (F, S, X, C, Phi)
    flip_count = 0
    for _ in range(steps):
        for op in ops:
            new_state = op(state)
            if not np.isclose(
                new_state.distribution[watch_idx], state.distribution[watch_idx]
            ):
                flip_count += 1
            state = new_state
    return flip_count


def generate_flip_counts(L: int, seed: int, N: int, steps_per_link: int) -> np.ndarray:
    """
    Compute one flip count per link of a periodic L×L lattice.

    Links are enumerated by ``build_default_lattice(L, boundary="periodic")``.
    ``count_flips_on_link`` is evaluated for each link, with ``N``,
    ``steps_per_link`` and ``seed`` passed as one shared parameter dict.  The
    counts are returned as a 1-D integer array in link order.
    """
    params = {"N": N, "steps_per_link": steps_per_link, "seed": seed}
    links = build_default_lattice(L, boundary="periodic")
    return np.fromiter(
        (count_flips_on_link(link, params) for link in links),
        dtype=int,
        count=len(links),
    )


def main() -> None:
    """Command-line entry point: generate flip counts and save them as ``.npy``.

    Steps:
    1. Parse the CLI options and reject out-of-range numeric values with
       ``parser.error``, which prints usage and exits with status 2.
    2. Run :func:`generate_flip_counts` and check that the result length is
       ``2 * L * L``.
    3. Only then create the output directory and write the array with
       ``np.save``.

    Raises
    ------
    RuntimeError
        If the generated array does not have length ``2 * L * L``.
    """
    parser = argparse.ArgumentParser(
        description="Generate tick‑flip counts for an L×L lattice"
    )
    parser.add_argument(
        "--lattice-size",
        "-L",
        type=int,
        default=4,
        help="Lattice side length (e.g. 6)",
    )
    parser.add_argument(
        "--seed",
        "-s",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--context-depth",
        "-N",
        type=int,
        default=2,
        help="Context depth (half‑range of distribution)",
    )
    parser.add_argument(
        "--steps-per-link",
        "-t",
        type=int,
        default=1000,
        help="Number of operator sequences per link",
    )
    parser.add_argument(
        "--output",
        "-o",
        default="data/flip_counts.npy",
        help="Output .npy file path",
    )
    args = parser.parse_args()

    # Reject nonsensical values up front, before any work is done.  Otherwise
    # they only surface later as tracebacks from deep inside the simulation
    # (negative depth or seed) or after the full run (negative lattice size).
    if args.lattice_size < 1:
        parser.error("--lattice-size must be at least 1")
    if args.context_depth < 0:
        parser.error("--context-depth must be non-negative")
    if args.steps_per_link < 0:
        parser.error("--steps-per-link must be non-negative")
    if args.seed is not None and args.seed < 0:
        parser.error("--seed must be non-negative")

    counts = generate_flip_counts(
        L=args.lattice_size,
        seed=args.seed,
        N=args.context_depth,
        steps_per_link=args.steps_per_link,
    )

    # Validate the result before touching the filesystem, so a failing run
    # does not leave behind freshly created directories.
    expected_len = 2 * args.lattice_size * args.lattice_size
    if counts.size != expected_len:
        raise RuntimeError(
            f"Generated flip counts have length {counts.size}, expected {expected_len}"
        )

    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.save(args.output, counts)
    print(
        f"Saved flip counts for {args.lattice_size}×{args.lattice_size} lattice to {args.output} (length {len(counts)})"
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script; main() returns None, so the
    # process exits with status 0 on success.
    raise SystemExit(main())